{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Monotonic Output Regression\n",
    "\n",
    "This notebook accompanies the paper [Learning Convex Optimization Models](https://web.stanford.edu/~boyd/papers/learning_copt_models.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('1.1.0a4', '0.1.3')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "from cvxpylayers.torch import CvxpyLayer\n",
    "import cvxpylayers\n",
    "from algorithms import fit\n",
    "import matplotlib.pyplot as plt\n",
    "from latexify import latexify\n",
    "cp.__version__, cvxpylayers.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convex optimization model: Euclidean projection of yhat onto nondecreasing vectors\n",
    "n = 20  # input dimension (number of columns of X)\n",
    "m = 10  # output dimension (length of y)\n",
    "\n",
    "y = cp.Variable(m)\n",
    "yhat = cp.Parameter(m)\n",
    "\n",
    "# minimize ||y - yhat||_2  subject to  y[i+1] - y[i] >= 0 for all i\n",
    "objective = cp.norm2(y - yhat)\n",
    "constraints = [cp.diff(y) >= 0]\n",
    "prob = cp.Problem(objective=cp.Minimize(objective), constraints=constraints)\n",
    "\n",
    "# Differentiable layer: takes a value for yhat and returns the optimal y\n",
    "layer = CvxpyLayer(prob, parameters=[yhat], variables=[y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get data\n",
    "def get_data(N, n, m, theta):\n",
    "    \"\"\"Sample N input/output pairs from the monotone regression model.\n",
    "\n",
    "    Inputs X are standard normal with shape (N, n). Outputs Y are obtained by\n",
    "    passing X @ theta plus standard normal noise of shape (N, m) through the\n",
    "    global `layer`. Returns the tuple (X, Y).\n",
    "    \"\"\"\n",
    "    X = torch.randn(N, n)\n",
    "    noise = torch.randn(N, m)  # drawn after X, so the random stream order is unchanged\n",
    "    Y = layer(X @ theta + noise)[0]\n",
    "    return X, Y\n",
    "\n",
    "# Fix the seed so theta_true and both data sets are reproducible\n",
    "torch.manual_seed(0)\n",
    "theta_true = torch.randn(n, m)\n",
    "X, Y = get_data(N=100, n=n, m=m, theta=theta_true)        # training set\n",
    "Xval, Yval = get_data(N=50, n=n, m=m, theta=theta_true)   # validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_loss = torch.nn.MSELoss()\n",
    "\n",
    "# Least-squares baseline: solve the normal equations (X^T X) theta = X^T Y.\n",
    "# torch.solve(B, A) was deprecated and later removed from PyTorch; use\n",
    "# torch.linalg.solve(A, B) when it exists (note the swapped argument order)\n",
    "# and fall back to torch.solve only on old installations without it.\n",
    "gram, cross = X.t() @ X, X.t() @ Y\n",
    "if hasattr(torch, \"linalg\") and hasattr(torch.linalg, \"solve\"):\n",
    "    theta_lstsq = torch.linalg.solve(gram, cross)\n",
    "else:\n",
    "    theta_lstsq = torch.solve(cross, gram).solution\n",
    "\n",
    "# Validation MSE of the plain linear predictor\n",
    "lstsq_val_loss = mse_loss(Xval @ theta_lstsq, Yval).item()\n",
    "# Validation MSE of the layer evaluated at the true parameter (no noise added)\n",
    "bayes_val_loss = mse_loss(layer(Xval @ theta_true)[0], Yval).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learnable parameter: zeros with the shape and dtype of the least-squares solution\n",
    "theta = torch.zeros_like(theta_lstsq, requires_grad=True)\n",
    "def loss(X, Y, theta):\n",
    "    \"\"\"Mean squared error between the layer's output for X @ theta and Y.\"\"\"\n",
    "    prediction = layer(X @ theta)[0]\n",
    "    return mse_loss(prediction, Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "001 | 6.37966\n",
      "batch 001 / 007 | 6.54562\n",
      "batch 002 / 007 | 7.21317\n",
      "batch 003 / 007 | 6.77889\n",
      "batch 004 / 007 | 6.75470\n",
      "batch 005 / 007 | 5.97233\n",
      "batch 006 / 007 | 5.50823\n",
      "batch 007 / 007 | 5.41007\n",
      "002 | 3.67326\n",
      "batch 001 / 007 | 2.82474\n",
      "batch 002 / 007 | 3.14247\n",
      "batch 003 / 007 | 2.68744\n",
      "batch 004 / 007 | 2.79289\n",
      "batch 005 / 007 | 2.84320\n",
      "batch 006 / 007 | 2.73215\n",
      "batch 007 / 007 | 2.60307\n",
      "003 | 2.77888\n",
      "batch 001 / 007 | 1.78423\n",
      "batch 002 / 007 | 1.69969\n",
      "batch 003 / 007 | 1.88519\n",
      "batch 004 / 007 | 1.79895\n",
      "batch 005 / 007 | 1.76285\n",
      "batch 006 / 007 | 1.71602\n",
      "batch 007 / 007 | 1.83642\n",
      "004 | 2.20847\n",
      "batch 001 / 007 | 1.27333\n",
      "batch 002 / 007 | 1.44328\n",
      "batch 003 / 007 | 1.33678\n",
      "batch 004 / 007 | 1.29789\n",
      "batch 005 / 007 | 1.30782\n",
      "batch 006 / 007 | 1.31228\n",
      "batch 007 / 007 | 1.23183\n",
      "005 | 1.79820\n",
      "batch 001 / 007 | 0.71949\n",
      "batch 002 / 007 | 0.82146\n",
      "batch 003 / 007 | 0.96098\n",
      "batch 004 / 007 | 0.94970\n",
      "batch 005 / 007 | 0.92515\n",
      "batch 006 / 007 | 0.97964\n",
      "batch 007 / 007 | 1.09865\n",
      "006 | 1.43201\n",
      "batch 001 / 007 | 0.36612\n",
      "batch 002 / 007 | 0.49633\n",
      "batch 003 / 007 | 0.52536\n",
      "batch 004 / 007 | 0.68597\n",
      "batch 005 / 007 | 0.74700\n",
      "batch 006 / 007 | 0.76915\n",
      "batch 007 / 007 | 0.83172\n",
      "007 | 1.16918\n",
      "batch 001 / 007 | 0.37163\n",
      "batch 002 / 007 | 0.53301\n",
      "batch 003 / 007 | 0.58942\n",
      "batch 004 / 007 | 0.63419\n",
      "batch 005 / 007 | 0.65475\n",
      "batch 006 / 007 | 0.66428\n",
      "batch 007 / 007 | 0.65300\n",
      "008 | 1.00531\n",
      "batch 001 / 007 | 0.50027\n",
      "batch 002 / 007 | 0.50870\n",
      "batch 003 / 007 | 0.55201\n",
      "batch 004 / 007 | 0.51065\n",
      "batch 005 / 007 | 0.54521\n",
      "batch 006 / 007 | 0.54851\n",
      "batch 007 / 007 | 0.52301\n",
      "009 | 0.96045\n",
      "batch 001 / 007 | 0.41113\n",
      "batch 002 / 007 | 0.34671\n",
      "batch 003 / 007 | 0.44695\n",
      "batch 004 / 007 | 0.38751\n",
      "batch 005 / 007 | 0.42404\n",
      "batch 006 / 007 | 0.47286\n",
      "batch 007 / 007 | 0.44319\n",
      "010 | 0.89814\n",
      "batch 001 / 007 | 0.49001\n",
      "batch 002 / 007 | 0.56290\n",
      "batch 003 / 007 | 0.50579\n",
      "batch 004 / 007 | 0.46224\n",
      "batch 005 / 007 | 0.42447\n",
      "batch 006 / 007 | 0.40952\n",
      "batch 007 / 007 | 0.49495\n",
      "011 | 0.82211\n",
      "batch 001 / 007 | 0.39616\n",
      "batch 002 / 007 | 0.35220\n",
      "batch 003 / 007 | 0.35311\n",
      "batch 004 / 007 | 0.35772\n",
      "batch 005 / 007 | 0.37887\n",
      "batch 006 / 007 | 0.38523\n",
      "batch 007 / 007 | 0.44267\n",
      "012 | 0.76901\n",
      "batch 001 / 007 | 0.20815\n",
      "batch 002 / 007 | 0.27799\n",
      "batch 003 / 007 | 0.30103\n",
      "batch 004 / 007 | 0.35757\n",
      "batch 005 / 007 | 0.38544\n",
      "batch 006 / 007 | 0.40748\n",
      "batch 007 / 007 | 0.39222\n",
      "013 | 0.71757\n",
      "batch 001 / 007 | 0.35951\n",
      "batch 002 / 007 | 0.40466\n",
      "batch 003 / 007 | 0.39934\n",
      "batch 004 / 007 | 0.39784\n",
      "batch 005 / 007 | 0.38302\n",
      "batch 006 / 007 | 0.36658\n",
      "batch 007 / 007 | 0.41987\n",
      "014 | 0.66855\n",
      "batch 001 / 007 | 0.58453\n",
      "batch 002 / 007 | 0.40711\n",
      "batch 003 / 007 | 0.40183\n",
      "batch 004 / 007 | 0.36751\n",
      "batch 005 / 007 | 0.36644\n",
      "batch 006 / 007 | 0.35535\n",
      "batch 007 / 007 | 0.34039\n",
      "015 | 0.65329\n",
      "batch 001 / 007 | 0.33166\n",
      "batch 002 / 007 | 0.32293\n",
      "batch 003 / 007 | 0.31314\n",
      "batch 004 / 007 | 0.31500\n",
      "batch 005 / 007 | 0.33010\n",
      "batch 006 / 007 | 0.34405\n",
      "batch 007 / 007 | 0.32332\n",
      "016 | 0.60960\n",
      "batch 001 / 007 | 0.28471\n",
      "batch 002 / 007 | 0.28110\n",
      "batch 003 / 007 | 0.27468\n",
      "batch 004 / 007 | 0.30219\n",
      "batch 005 / 007 | 0.31693\n",
      "batch 006 / 007 | 0.31605\n",
      "batch 007 / 007 | 0.35820\n",
      "017 | 0.62811\n",
      "batch 001 / 007 | 0.28251\n",
      "batch 002 / 007 | 0.24609\n",
      "batch 003 / 007 | 0.26830\n",
      "batch 004 / 007 | 0.26450\n",
      "batch 005 / 007 | 0.32891\n",
      "batch 006 / 007 | 0.33223\n",
      "batch 007 / 007 | 0.33829\n",
      "018 | 0.65671\n",
      "batch 001 / 007 | 0.22490\n",
      "batch 002 / 007 | 0.23356\n",
      "batch 003 / 007 | 0.25540\n",
      "batch 004 / 007 | 0.29610\n",
      "batch 005 / 007 | 0.31102\n",
      "batch 006 / 007 | 0.32254\n",
      "batch 007 / 007 | 0.32085\n",
      "019 | 0.60048\n",
      "batch 001 / 007 | 0.31201\n",
      "batch 002 / 007 | 0.28156\n",
      "batch 003 / 007 | 0.29576\n",
      "batch 004 / 007 | 0.34484\n",
      "batch 005 / 007 | 0.33622\n",
      "batch 006 / 007 | 0.32448\n",
      "batch 007 / 007 | 0.33567\n",
      "020 | 0.56205\n",
      "batch 001 / 007 | 0.27252\n",
      "batch 002 / 007 | 0.33209\n",
      "batch 003 / 007 | 0.36123\n",
      "batch 004 / 007 | 0.34777\n",
      "batch 005 / 007 | 0.33455\n",
      "batch 006 / 007 | 0.33327\n",
      "batch 007 / 007 | 0.33440\n"
     ]
    }
   ],
   "source": [
    "# Fit theta with Adam on mini-batches; `fit` is the helper imported from algorithms\n",
    "val_losses, train_losses = fit(\n",
    "    lambda X, Y: loss(X, Y, theta),\n",
    "    [theta],\n",
    "    X, Y, Xval, Yval,\n",
    "    opt=torch.optim.Adam,\n",
    "    opt_kwargs={\"lr\": 1e-1},\n",
    "    batch_size=16,\n",
    "    epochs=20,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAB+CAYAAAA+/4aWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deVhU193A8e8BBFFABLRi0Cg2qGlUZFHQuDTBplp920bEpLYJ1QrGPGl9425iH6tJrKk21jQaxqaxj2msgnnbaqPGsbWJAayImtRYTTqaKERRZFQ2Wea8f8wSlmEbZuYOcD7Pcx/HOffec7hw5zf3rEJKiaIoiqI05KV1ARRFURTPpAKEoiiKYpcKEIqiKIpdKkAoiqIodqkAoSiKotjlo1XGYWFhctCgQVplryicPHnyhpSyj1b5q3tA8QTN3QeaBYhBgwaRl5enVfaKghDicy3zV/eA4gmauw9UFZOiKIpil0cFCCklOTk5nD59WuuiKIqidHmaVTE1JTk5mcmTJ/PHP/5R66IodVRXV3PlyhUqKyu1LorDwsPDCQ4O1roYTVLXWHG2iooKunfvjhDCoeM9KkAIIUhMTCQ3N1froigNXLlyhcDAQAYNGuTwH5uWKioqKCgo8OgPL3WNFWdbu3YtO3bs4PLly/j4tP3j3qOqmAASEhIwGAwUFRVpXRSljsrKSkJDQzvkBxdA9+7dqa6u1roYzVLXWHG27OxsBg4c6FBwACcHCCFEmhAiSQiR7Og5EhISADh+/LjTyqU4R0f94IKOU/aOUk57OnLZO6Oqqir+9a9/MX78eIfP4bQAYQkKBimlXkqZ5eh5YmJi8PHxUdVMXVR+fj5DhgxBr9c3ei8rKwu9Xo9Op9OwhB2fusZdw+nTp6msrGTcuHEOn8OZTxBTgEghRLIQIsneDpYnjDwhRN7169ftnqRHjx6MGjVKBYguKiYmhsjISJKSkhq9l5ycTFJSEpmZmRiNRg1L2bGpa9w1fPjhhwDtChDObqTOk1LmCyEOA/qGiVJKHaADiIuLa3IhioSEBP7whz9QW1uLt7e3k4uotNfkyZNb3Gf69OksWbLEtn9qaiqpqancuHGDsLAwh/LV6/VkZmYyZcqUTt8Qqq6x0l7jx49n3bp19O/f3+FzOPMJ4r/OOlFCQgKlpaWcPXvWWadUOqi632KTkpLYsGGDhqXpnNQ17pzGjBnD888/365zOPMJQgekCSEigXb9hSUmJgKQm5vLyJEjnVA0xZmOHj3q8P5t/War1+tJTv6qz0NwcDC7d+9m2bJlbTpPR6OusdIexcXFnDt3jvj4ePz8/Bw+j9OeIKSURinly1LKLCllo+qltoiMjCQsLEy1Q3RB+fn5GAwGW0Pp7t27be/l5+cDEBcXh06nU3XkDlLXuPM7ePAgEyZM4Ny5c+06j0cNlLMSQpCQkKACRBcUExPDf//7VW1lWloaQL33MjIy3F6uzkRd487vww8/JCAggBEjRrTrPB43UM4qISGBc+fOqW8wSodk6bGnWnkVTWRnZ5OQkNDuTj4eHSBADZhTOh5LYJgChGhdFqXruX37Nh9//HG7BshZeWyAiI+PRwihqpmUjigOOKF1IZSuKTc3F5PJxLhx4ygrK2PTpk0UFxc7dC6PDRBBQUF84xvfUAFC6VCEEDFAk6sAtWawqKK0x8SJE/nggw8YP348u3fvZsmSJXzyyScOnctjAwSYu7seP34ck8mkdVEUpbUiMT9BxAONZhSQUuqklHFSyrg+fTRb7VTpxLp3786DDz5Iz5490el0DB8+nAcffNChc7UYIIQQS4QQg4QQe4QQP3EoFwclJCRQUlLCp59+6s5sFY3l5+ej0+nQ6/Xo9XpefvllwNxfv+5cQdYODPn5+cTGxmIwGGzn0Ov1TJkyxe2dHCzzkOUBHt1A3ZGvsdK02tpannvuOc6cOcOZM2c4fvw4
aWlpLl0P4hSQDiwHBjuUi4OsDdW5ubkMHTrUnVkrGjEajWRkZNTrZnn48GEMBgOZmZm2941GI8uXLycjI4OYmBjS09PJyMiwjQI2Go3ExMRoMl2ElNKIuZHaI3WGa6zY9+9//5uXXnqJ+++/n+zsbPz8/HjiiSccPl9rx0HcBCQQA/zd4dzaaNiwYQQFBZGTk8OTTz7prmyVZixatKjdS8JGR0ezefNmu2l79uwhNja23nvp6em2b6tWwcHB5OV9VdUfGRnJyZMnATAYDERGRnLiRMdsJ1bXWHGUdYK+6OhoFi5cyKxZswgJcbwzXWsCxGggC3gZeM/hnBzg5eXF2LFjVUN1FxcZGdmqaowpU6ag1+sJCQkhMjLSDSXrPNQ17hyys7MJDw8nNzeX27dv2wZBOqq1VUxpaFDFBOZqphdffJHS0lICAgLcnb3SQFPfSp0lJSWF5cuX13vPOldQ3UnkjEZjow+o5ORk0tPTSU9Pd2kZXU1dY8VRH374IePGjWP79u3tapy2am0VUzEaVDGBuSeTyWQiLy+vVVMgKx1bcHAw6enp6HQ627da60Rys2bNsr2fn5/P9u3bAXMDakZGBnFxcQQHBxMTE4NOpyM/Px+j0ajqyBtQ17hzMhqN3Lp1i0GDBrF3715eeeWV9q/yJ6VsdgMeBvYAh4CHWtq/tVtsbKxsjRs3bkhArl+/vlX7K67xySefaF2Edmv4M2Bev8Qpf8+ObA3vgc54jRX3qq2tlenp6dLPz08WFxe36pjm7oPWPEH0klKmAAghHm1fOGq70NBQoqKiVDuEoihKCyoqKti1a1e7G6etWhMgQpt47TYJCQkcPHgQKaVaGF1RFMWOH/zgB9TU1DilcdqqNSOp9ZZBcruBw07JtY0SEhIoKiri0qVLWmSvKIri0crLy8nMzOTYsWNOaZy2svsEIYQYDczG3DAtAOvwyQ2W992q7oC5wYPd3pFKURTFo+Xl5VFTU8OXX37JsmXLnFbT0lwV03op5a26bwghHnZKrm00YsQIevToQW5uLo8//rgWRVDcyDoIS3EddY07F+sAOV9f33aNnG7IbhWTlPJUw+Bgef+I03JuAx8fH+Lj41VDdRdgMBjIysrSuhidmrrGnc8HH3yAl5cXKSkpTmmctnL6kqNCiDRgjzTPR+M0CQkJ/PrXv6ayspLu3bs789SKB8nPz+fEiRO2NZIPHzY3e82aNYv8/HzS0tKYP38+mZmZtnWVAZKSklz2jVgIMQiIxTxQ1ADkSykvuSQzN/DEa6y0T3V1NSaTyWmN01ZODRB1VtLSA04PENXV1eTn5zNu3DhnnlppI3sDFlNSUli4cCHl5eVMmzatUXpqaiqpqancuHGDsLCwJs+dlJSEwWAgJibGNtfPhg0bMBqNHD58mODgYNs3pOXLl7Ny5UrANVUmQojBmKfsLgbyMf9dhwCxQogkQO+qQNFVrrHiHHfu3HFq47RVmwKEEGJQCzdEsytpWZ4u0gAGDhzYlqwZO3YsYG6oVgGi8zMYDISEhBAaau5ZXXek7s2bN22vIyMjCQ4OdtmU01LK7Q3eugVcBFsA6bA85Ror7ZObm8vx48edM3K6gRYDhBBiKTDE8t9YzAuh2NvPupJWTFPnklLqAB1AXFycbEtBw8PDuffee1U7hAc4evRok2k9evRoNr25b7Zg/pAqLi7GYDDYqkLq0uv1GAwGDAYDGzZsQKfTERMTQ1xcXFt+hFaRUl4UQszE/FQ8GHM3b52U8rY13emZWnSVa6y039y5cxFCOLVx2kqYR1o3s4MQD0kp/255Pbipm0IIkYy5WikdOGwJBk2Ki4uTdacSbo3HHnuM7OxsvvjiizYdp7TfuXPnGD58uNbFaJeGP4MQ4qSUstlPPSHETCnlXkugAHOw2GO9J9qj4T3QGa+x4lplZWUEBQURHh7OlStXHDpHc/dBawbKLRBCbBNC/BJ4vamdpBtW0kpISODy5csUFBS4KgtFachg
+dsfLKXcK6VcAPTWulCKAvDGG29gMpn4zne+45Lzt6YNYreUci+0PA5CunglrcTERAAOHDjAT37i1tVPlS5KSnkKOCWEmGkJFJJmZhSwdNSIw1zVmi+l1LunpEpX9NprrwHmaTZcocUnCGtwsLzWZByEVVxcHGPHjmXFihVcu3ZNy6IoXYzl6WGFlHJlC9VLKYBBSvky5jVUFMUlzpw5w4ULF/Dy8mLMmDEuyaPFACGEeNgyF9MhIcRDLilFK3l7e/Pmm29SWlrKwoULaan9RHGuioqKDnvNa2tr3ZKPlFInpTQIISL5aooaGyFEmhAiTwiRd/369UbHq2ustNb27dvp1q0ba9aswd/f3yV5ePx03w0NHz6cX/ziF6xYsYLMzExSUlK0LlKXEB4eTkFBAdXV1VoXxWHtGWEqhFgipdxoeb0USzWTlLKpxaPTsfME0VxPvq5+jZXWKysrY+fOnaSkpLB69WqX5aPZdN/nz59vcYW46dOns2TJEsA8cMg6EOjJJ5/kxRdf5Ic//CGbN2/G19fX7vF1Bw4lJyezePFiZsyYwfnz51u1ZGLD/V966SXGjRtHdnY2q1atavH4hvtnZGQwdOhQ9u3bx6ZNm1o8vuH+WVlZhIWFsWPHDnbs2NHi8Q33t3aN3LhxI/v372/x+Lr75+TksHevubZx5cqV5OTkNHtsaGhovf2Li4vR6cwd29LS0rhw4UKzx0dFRdXbPzQ0lPXr1wMwc+ZMiouLmz0+MTGx3v6JiYn1/pYcsLfO66zmurhaevStxzyortWDB4KDg9XKbEqr7Nmzh9u3bzNhwgRqamrw8XH6pBhAB5nuuyEfHx+GDh1KTU0Nn332mdbFUbqGXtYXrQgOK4FMzE8RiuJ0Op2Ovn37smDBAuxVVTpLk+MghBA/kVL+zvI4bX1yGCyldMp0346Mg2ho/fr1rFq1iszMTNuauorSWq0ZB1Fn3yVSyo1CiCAgSUr5Tnvzd8Y9oHQ9H330EaNGjWLkyJHcuXMHg6FRU1ebNHcfNPdcctLyb76195JW0303ZenSpezdu5eFCxcyefLkFkeQKko7hAkhtmGuMmq+fktRXEin0+Hr68vVq1eZMsVlowqAZqqYLP2/wXIzWGa09KjuFT4+PuzYsQOj0cgzzzyjdXGUzu2wlPIp4JfY6Z2kKO5QXl7Ozp07mTp1KkVFRYwfP96l+TUZIIQQgy3VSyuFEEuAWbhwEJyjHnjgAX7+85/zpz/9iXfeafdTv6I0xSCECLKsk+KaTueK0oLdu3dz+/ZtRowYAeDyiUubnYvJMltlDOZpjgGkdaKy9nJm/Wt1dTVjx46loKCATz75xDY7paI0p6U2iLrzkDmS3hLVBqG0VWJiIkajkRMnTvCvf/2LSZMm4e3t3a5zOtoGYZ3N8ibwMOa1qZOAp9pVGhfo1q0bO3bsIDY2lmeeeYY//vGPTp/2Vul6pJR/F0LMx9yDyQiUAJGYq1qNUsrfaVk+pWv56KOPyM3N5de//jUBAQE89JDrxy23pptrGuabwoh50RSPNHLkSH7+85+za9cuVq1a1WFHoyqeRUq5XUq50RIMTgIZdf6vKG6j0+nw8/Pj+9//Ps8//3yLY4mcoTWjK4zATcyN1b1a2FdTzz33HIWFhfzyl7+kvLyczZs3qycJxSksPfhmYX568LinaKVzszZOJycnc+HCBV588UUmT55MVFSUS/NtTYDIk1Kessxk6dGj0ry8vNi6dSv+/v688sorVFRUsG3btnbX0SldlxBiNOYR0VJKuUAI0cs6Rkjrsildh7VxOj09Hb1e79IJ+upqMUBYu7tKKVdYurp6NCEEmzZtokePHrz44otUVFTw5ptvumwoutK5Wf/+hRDzLU+jJUBvIUR0M/MwKYpT6XQ6hg0bxoMPPsjatWsZMWIEQUFBLs/X7qdmg8fpEMxVTIJmlhz1JEIIXnjhBfz9/Xn++eepqKjg7bffbnLOJkVphZtAsKXhOkRK
eVoIMbrOeCFFcYm6jdO1tbXk5ua6ZHlRe5r6Wp1Xd/S0p46kbslzzz1Hz549+d///V8effRRsrKy6N69u9bFUjog67KjQoj1QIblPRUcFJezNk4/8cQTfP755wghXD7+wcpugLAMBrKKFEJYVzUf7PoiOdeiRYvw9/dnwYIFzJgxgz//+c/07NlT62IpHZBl8ay9Le6oKE5St3E6NDSU0NBQSkpK3Lb2Rmsq5vXAy5irmza4tjiukZ6ejr+/Pz/+8Y/55je/ydtvv83Xv/51rYulKIrSpMLCQn7zm9/YGqetvL293dbxpskAUaenRjJfzWm/AbA7m6unr8X7xBNPEBQUxI9//GNGjRrFpk2bSE9PV91gFUXxCDU1NRw/fpx3332Xffv28fHHHwMQERHBkSNHmDBhAtOmTSMwMJCYmBiWLzevR7VixQouX76MyWTCZDJRW1tre20ymZgzZw6zZzs2CXdzTxDWOQBOWqcTaKENIgXQSyn1QojDfDU9h8f43ve+R1xcHHPnzuWpp57ir3/9K2+88Qbh4eFaF01R6iksLKR79+5qhbZOrqioiD/96U9kZWVx4sQJKisrG+1jNBr585//zPz58zlw4ADR0dFUVFTY0k+fPs1nn32Gt7c3Xl5eeHl51Xt9+7bjsyM11YtpNDBbCDHb/F8xBXMvpsHAEXvHWJZSpKm1eD1FREQEBw8eZOvWrSxbtowHHniA119/nVmzZmldNEWxefrpp/nggw946aWXmDdvnhrL00mYTCZ27tzJ73//eyoqKsjLy6s360N4eDgJCQmMHTuW0aNHM3r0aPr06QNAZmYmYF5psu4YiIMHD7quwFLKRhswGvNa1A3f/669/RvsswFzd0B7aWmYn0zyBg4cKLX2n//8R8bHx0tAzpkzR5aUlGhdJMWNMPfWa/bvua0bEAwsw1w1G9PcvrGxsU2W7aOPPpITJ06UgIyLi5O5ubnO/vEVFzIajTInJ0e+8cYbcv78+XLo0KHyu9/9ruzTp4/E3J4rY2Nj5bp16+SuXbvk+++/L2/fvt3sORctWiT9/f1lVVWVU8va3H3Qmj/4IOBRy7athX2TLTdIZEvnbe7mcKeqqiq5Zs0a6e3tLSMiIuTBgwe1LpLiJi4KEMusX5CADc3t29I9YDKZ5Ntvvy379+8vATl37lxZVFTkxCugOENhYaHMyMiQP/vZz+SUKVNkeHi4LQjU3Xr16iXnzJkjf//738uCgoI25zNmzBg5ceJEp5e/vQFiKbAE84yu85vZLxnzZGaHW7oxpAcFCKsTJ07IYcOGSUDOnj3boV+g0rG4KEBk2nttb2vtPXD79m25dOlS6ePjI4ODg+Wrr74qq6ur238BlFa7fv26vHr1qpRSyoKCAvnQQw/J/fv3Syml/Mtf/iIB6ePjI/39/W0BwdvbW44ePVquWrVK5uTkyJqaGofzN5lM8oknnpCbNm1yys9TV3sDxHxgJhANPNrS/q3dPC1ASCllZWWlXLdunfTz85OBgYFyy5Yt7fqlKp7NDQHisJ10h6tZz507J5OSkiQgx44dq/42nayqqkqeP39e7tu3T27atEmmpaXJSZMm2aqFlixZIqWU8sKFCzIyMlJ+61vfsn2pBKS/v7986KGH5Jo1a+SRI0dkWVmZxj9R6zR3HzS7YBCAEGKilPJ962R90kmTlHnyYimfffYZTz/9NO+99x6xsbG8/vrrxMW1am17pQNpacEgB8+5DMiSUhqEEBlSyvSm9nXkHpBS8s4771BYWMgzzzyDlJKbN2+qRbLaKDs7GyEEiYmJ1NbWMmLECD799FNqamps+4SGhhIVFUVERARCCMrKyjh79iyXLl0CoFevXjz44INMmDCBiRMnEhsb67LpfMrLy/H393dJt/zm7oPWBIjXgdelkycm8+QAAeYbMTMzk0WLFnH16lWefvppXnjhBXr18ugZz5U2cFGACMb8lGAADFLKJtdQccY9sG/fPubMmcM//vEP
YmNj23UuZ9Lr9fzsZz+jrKxMszLU1NRw9+5dampqqK6uRghhC6SFhYX4+PjQt29fAIqLi/H29sbHx4du3brRrVs3vLy8KC0tpbi4GIC+ffvagsHEiRMZMWKE23qXTZ8+nfLycv7+d4cXMGySwyvKAUgpF1hO8hNgiJRypZPL55GEEKSkpPDII4+wevVqXnvtNbKysli6dCnz5s1TgUKxS0ppxDzzgFvcf//9/OhHP2LkyJEA3Lhxg7CwMHdlb9f+/ftJTk5m0KBBTJ482W351tbWcu3aNa5cuUJBQUG9/v89e/YkLCyMSZMmAVBSUoKvr2+L0+74+voSHx/PxIkTiYqK0mRgrclkIicnh+9///tuz7vFACGE2GZ5+V/gl64tjufp1asXW7Zs4cknn2Tx4sUsXryYNWvWMG/ePH76058yeHCHm55K6USGDBnCa6+9BmBbzH7ChAls2rSJAQMG2PaTUlJWVsatW7cICQnB39+fy5cvk52dbRud+49//IN33nmHW7ducevWLSoqKhqNyjWZTOzatYuBAweyc+dONm/ezAcffECPHj14+eWX2bJlCwUFBfj5+VFbW8udO3eIi4sjNjaW2NhYl1WFbdiwgbVr11JeXo6fnx+TJk1i6tSpfPOb3yQqKgp/f3+X5OsO58+f5+bNm4wfP97tebdmydF8KeVT0rzM4q2Wd++cYmNjOXr0KHl5efzP//wPv/3tb/n617/OzJkz+fDDD2mpqk5RXM3X15eFCxeyb98+hg0bxqhRoxg0aBC9e/fGx8eHwMBAIiIiyM7OBiAnJ4fHHnuML774AoCzZ8/y1ltv8f7773Pp0iXu3LnD3bt3MZlMeHl54evrS48ePWzfogMCAujfv78t/4sXL1JYWEhYWBgzZsxgxIgRnDlzhlWrVvHII48QFhbGkCFDqKqqAsxtfTdu3HDoZ/3nP//J8OHD+fzzzwG47777SE1NZf/+/RQXF3Po0CEWLVrEqFGjOnRwAGy/L3fN4FpXi20QruLpbRAtKSgo4LXXXuP111+npKSE+Ph4nnnmGaZNm6YaDDsIV7RBtIWr7oFLly6xdu1aiouL6dWrV6Nt2rRpREREYDQaKSwsZMiQIfj5+bUrzzfffJN58+YxadIk9u3bR0BAgC2tpKSEU6dOcfLkSQoLC3nllVcA+Pa3v82XX37JmTNnAHjrrbcICwsjNjbWNnrY+vMcOHCAAwcOkJqayqOPPsr58+f56U9/yq9+9Stb9VpnNXfuXP76179y/fp1tzdSO7WLX1s2T+zm6ojS0lK5detWGRUVJQEphJBjxoyRq1evlseOHVP91T0YLujm2pats9wDW7dulYD81re+1aaunceOHZPvvvuulNLczz84ONjWZXTAgAFyxowZcvjw4bb3Bg8eLP/whz+46sfwWH/729/ktm3bXHb+5u4D9QThJCaTiePHj3Po0CHee+89jh8/jslkIigoiIcffphHHnmE+Ph4wsLCCA0Nrfeormijsz5BuNMrr7zCs88+y4wZM9izZ0+7FuQyGo22J42TJ09y+vRpBgwYwLRp05g6dapmjcSdXbu6ubpKZ7g5mlNSUsKRI0d47733OHTokK2e18rPz88WLEJDQwkLC6Nfv37cc8893HPPPfTv39/2uu7juuI8KkC0z/r161m1ahUzZ85US/q6yMWLFzEajYwaNQovr9Y0Gbddu7q5Ko7p3bs3ycnJJCcnI6Xk/Pnz/Oc//6G4uJgbN25QXFxc7/WZM2c4dOiQ3al5g4KCiIiI4N5772Xw4MG2bdCgQQwePJjevXurb1aK20gpWbNmDWvXrmXOnDns2LEDHx/P+SgpLS1l37597N69m3//+98AzJkzh1/84hdUV1czfPjwRsekp6ezdOlSjEaj3UGxzz77LAsXLqSwsJCJEyc2Sl+9ejVPPvkkn376KVOnTm2Uvn79embNmsWpU6fszhy9ZcsWpk2bxrFjx0hNTbW9bzQauX37Nrdu
3dKksd1zfqudmBCCYcOGMWzYsBb3LS0tpaCggIKCAgoLC22vr1y5wqVLl8jJycFoNNY7JigoiIEDB9KvXz++9rWv0bdvX772ta812vr06aO+5SntIqVk+fLl/OpXv2LevHlkZGR4xFTkUkqEEEgpuf/++7l8+TL9+/dn4sSJeHt727qjCyFISEhodLy1S7CPj4/d9HvuuQcw9xSzl96vXz8A/P397aZbB+QFBgbaTbd2bOnVq1ej9DFjxmjWE0tVMXVARqORS5cucfHiRdt2+fJlrl27ZtvKy8vtHtu7d2/69u1r26wBpU+fPo220NBQj7j5XUVVMbWNyWRi0aJFvPrqqyxcuJBXX33VZdUerXH37l0OHTrE7t27OXv2LKdOnUIIwa5du4iIiGD8+PGalq+jUFVMnUxwcDDR0dFER0c3uU9paSlFRUW2gGF9XVRUZNvOnTvH0aNHbVMJNGSdmiA0NBR/f398fX3x8/Oz/Wt93bNnT0JCQggJCbHtb91CQkLo1q0blZWVtu3u3bu21xUVFZSVlXHnzh1KS0spLS21vb5z5w5VVVX06dOHfv361dvCw8Pp27eveiJyE5PJxIIFC9i+fTvPPvssGzdu1Kxa89SpU2zZsoX/+7//sw38e/TRRykrKyMgIIDHH39ck3J1RipAdFIBAQEEBAQQGRnZ4r41NTUUFxdTVFTE9evXG23FxcVUVlZSVVXF3bt3qaio4NatW9y9e5eqqipKS0u5efNmvWUQ28Pf35/AwEACAgLo1q2brZ3GnuDgYHr37m3b6v7fGkinTZvmlHJ1VTU1NcydO5edO3fy3HPPsW7dOrcGh9raWv75z39y3333MWDAAC5dusQ777zD9773PR577DGSkpLo1q2b28rTlagAoeDj42Nrp2iPiooKW+N73c1kMtG9e3e7m5+fny2YBQYG0rNnT7vVWlVVVVy7do2rV69y9epVvvzyS65evUpRURFGo5GSkhLbwC/r/ysrK/nBD37QIQPEJ598wsWLFxu9P23aNIQQfPzxx416xnl7e/Ptb38bMH/LLiwsrJfu5+dHUlISACdOnKCoqKhees+ePW1zJ+Xm5lJcXExNTQ0bN27k2LFjzJs3jxdeeAGAY8eOcetW/SFrJtEAAAd3SURBVIkVQkJCSExMBODo0aONJurr27cv8fHxgHkyv7t379ZL79+/P6NHjwbg0KFDVFRUoNfrycrK4tq1a6xdu5bVq1czffp0rl271q4utUorNTVAwtVbZxkkpHiuioqKZpdxxIMHyi1evNjuqmQmk0lKKWVaWlqjtJ49e9qOf/zxxxul9+vXz5Y+Y8aMRun33XefLX3y5MmN0kePHm1Lj42NbZRed7Uz68DRutt3vvMdW7q9Vdcee+wxW3pgYKAEZPfu3eXMmTNlZmZmh1lfoaNp7j5QjdRKl+WpjdR3797l3Llzjb7hA0RHRyOE4IsvvmhU7ebl5cWoUaOAr/rP1+Xj48OIESMAMBgMjZ4AfH19+cY3vgHAuXPnWLFiBTk5OSxdupRZs2bRo0cPW/rZs2cbdYQIDAy09dT7+OOPqaysrJfeq1cvoqKiADh9+jTV1dX10kNCQhgyZAgAJ0+exGQyMWzYMAIDAxtdB8V5PLKR+vz5842mAk5JSWHhwoWUl5fbrRZITU0lNTWVGzdukJyc3Cj9qaeeYvbs2Vy+fJkf/ehHjdIXL17MjBkzOH/+POnpjddxef7550lKSuL06dMsWrSoUfpLL73EuHHjyM7OZtWqVY3SN2/eTHR0NHq93vYoXldGRgZDhw5l3759bNq0qVH6zp07GTBgALt372bbtm2N0rOysggLC2PHjh3s2LGjUfq7775Ljx492Lp1K3v27GmUfvToUQA2btzI/v3766X5+/tz4MABANatW8eRI0fqpYeGhrJ3714AVq5cSU5OTr30iIgI3nrrLQAWLVrE6dP1lw+JiopCp9MBkJaWxoULF+qlR0dHs3nzZgB+
+MMfcuXKlXrpiYmJrF+/HoCZM2c2+nB8+OGHWb16NQBTp05t1B4yffp0lixZAuDWKagdcfjwYWbMmKF1MRBC8Lvf/Y558+Y1SrMGiqZYA1FTmutgAXjU2hZdmWqDUBQP88ADD/Db3/5W62IwcuRIJkyYoHUxFA05rYqpLStpgapiUrTnqVVMiuJOzd0HzhxFkgbopJRZwGwnnldRFEXRgDMDRLw0L7cI0HLne0VRFMWjuaoNItjem0KINMxPGgClQojzLsq/PcIAx5a5cj1PLZunlguaL9u97ixIQydPnrwhhPi8iWRPuKaqDNrn744yNHkfODNAnBBCREopDZjbIRqRUuoAnRPzdDohRJ6W9dLN8dSyeWq5wLPLJqXs01SaJ5RblUH7/LUugzMDhA5IE0IYgAwnnldRFEXRgNMChKX94WVnnU9RFEXRlpoLtzFPrgLz1LJ5arnAs8vWHE8otyqD9vmDhmXQbKoNRVEUxbOpJwgPJYQIFkIkCSGW1fn/MiFEshAixsPKFimEOCmEyBBCqC7OitJJqKk2LCwfbJlAHrDB0htLM1JKoxAiD7AGA+tARKMQYgPQ7Eh1N5cN4OE642A0YxnRH4e5bPmYf5+tHuGvtbbOSODCMtiuoZRS7+4y1ClLGrBHi78tS94GINgyANjtLF8GQwC0+D2oAFGfR3zINSFeSmntBOCJ39JTLIvI5Gn8IZwC6KWUeiHEYeAwHhJYW8kTvgg0vIaaBAhLoJpiyd+t96UQIhlzgNYyOCYBNy2/h7QWD3ABVcVUX4oQIk3rKpxWsDsQUStSSoOUUmcZ59J4mlz3lkUnpTRYnggNdLwR/pqX18411EoccEKjvKcAkZYq3SQtCmAJTtuFEBlA4+mZ3UAFCAtP+pBrwok69fuaVn81ZAmq1qAVomlhvpIOLG/wnkcF1lbQurz2rqFbWL6kaT2TYZ6laknLazAf+C+wUosyqABh4aEfcinAFEtg0AHJlkdfTxiIWLdse4A4y2OwJjdTXZZrtB7z79FjA2sTPKK8Da6hFiIxP0HEA1p8g/+vBnk2lCSlzLdULdtflN3FVDdXizoNc5GY6187woeJ0oDlg20lcBNz/f16VCN1W8tQ7xpKKbX6Bh2MueNIpuXJ3t15W38PRi3aIixPENZqvhBNyqAChKIoimKPqmJSFEVR7FIBQlEURbFLBQhFURTFLhUgNCSEiLEOgHF0igrrcXXPpSiK4gyqkdoDWD7kk+uMlHbpcYrSWVgHtXp677SOSk21oSHLCM0YLCN+LX/sRr7q9623pE+x/D8Dc1fcYMzjImLqHBcCxEgpX7YEjiTM3RSNlv1nW46PUQFF6SxUYHAtVcWkIUu/5lDMgeCE5Y99A+YRpHlYxmRg7oedbkm/aTk8qe5xdc4F5skGdZZRoOmWf29a9hnirp9PUVzJUq26QetydGYqQHiQuiNoLcHAOtVAsSV9A+anjfwmjrPHOjrcUychVBRHGdB+OpJOTQUIDVmqmCItk7OFYn5iWI55bW9rNVMS5ukGwDz8P9KyWaudQjFPKpbEV5O7LbdMHZIMbLBWZVkCSZxas0HpRIx1pshRnEw1UiuK0iFZO2kAWWpqHNdQAUJRFEWxS1UxKYqiKHapAKEoiqLYpQKEoiiKYpcKEIqiKIpdKkAoiqIodqkAoSiKotilAoSiKIpi1/8DBnWXK5pdZ7YAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 394.92x129.6 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "latexify(5.485, 1.8)\n",
    "fig, ax = plt.subplots(1, 2)\n",
    "\n",
    "# Left panel: validation loss of the learned model against two reference levels.\n",
    "# The x-axis is derived from val_losses so it stays valid if `epochs` changes.\n",
    "epochs_axis = np.arange(1, len(val_losses) + 1)\n",
    "ax[0].axhline(lstsq_val_loss, linestyle='-.', c='black', label='LR')\n",
    "ax[0].plot(epochs_axis, val_losses, c='k', label='COM')\n",
    "ax[0].axhline(bayes_val_loss, linestyle='--', c='black', label='true')\n",
    "ax[0].set_xlabel(\"iteration\")\n",
    "ax[0].set_ylabel(\"validation loss\")\n",
    "ax[0].set_ylim(0)\n",
    "ax[0].legend()\n",
    "\n",
    "# Right panel: the three predictions for a single validation example\n",
    "example = 13  # index of the validation sample shown\n",
    "ax[1].plot((Xval[example] @ theta_lstsq).numpy(), '-.', c='k', label='LR')\n",
    "ax[1].plot(layer(Xval[example] @ theta)[0].detach().numpy(), c='k', label='COM')\n",
    "ax[1].plot(Yval[example].numpy(), '--', c='k', label='true')\n",
    "ax[1].legend()\n",
    "ax[1].set_xlabel('$i$')\n",
    "# Raw string: a backslash followed by 'p' is not a valid Python escape sequence\n",
    "ax[1].set_ylabel(r\"$\\phi(x;\\theta)$\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figures/mono_regression.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.3753725951890483"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Validation MSE of the least-squares (LR) baseline\n",
    "lstsq_val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.5115945280018195\n"
     ]
    }
   ],
   "source": [
    "# Training MSE of the layer when evaluated at the least-squares parameters\n",
    "print(loss(X, Y, theta_lstsq).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5620522413492108"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last validation loss returned by fit\n",
    "val_losses[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2637994455876921"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Validation MSE of the layer evaluated at theta_true\n",
    "bayes_val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
